CUDA: use mma PTX instructions for FlashAttention #11583

Merged · 6 commits · Feb 2, 2025

Conversation

JohannesGaessler (Collaborator)

This PR replaces the WMMA-based CUDA FlashAttention kernel with a kernel that instead uses mma PTX instructions to access the tensor cores. These kernels are typically used for batch sizes >> 1, but also for batched inference. The principle of the new kernel is the same as with MMQ: mma.cuh contains primitives that expose otherwise inaccessible PTX instructions to CUDA code. The primitives are similar to the WMMA interface, but they have a well-defined data layout, which allows for better optimization.

The data layout is the same for FP16 and int8 from Turing onward. I replaced INT8_MMA_AVAILABLE with NEW_MMA_AVAILABLE to better reflect this. The tensor cores on Volta are not compatible with the new code, so for V100s the old code is still used. Long-term I plan to purchase a V100 and write a dedicated kernel. There was interest from @IMbackK regarding the use of AMD tensor cores; if they could be made to fit the interface in mma.cuh they could in principle work, but to me this seems an unlikely prospect.
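
As a rough illustration of what such a primitive looks like (a hedged sketch, not the actual interface in mma.cuh), the Turing-generation m16n8k8 FP16 mma instruction can be exposed to CUDA code via inline PTX roughly like this; the per-thread fragment sizes and their lane-to-element mapping are fixed by the PTX ISA, which is what gives the well-defined data layout mentioned above:

```cuda
// Hedged sketch, not the actual mma.cuh interface: expose the Turing
// mma.sync.aligned.m16n8k8 instruction (FP16 x FP16 -> FP32) via inline PTX.
// Per thread: a holds 4 halves in 2 registers, b holds 2 halves in 1 register,
// c/d hold 4 floats; the lane <-> element mapping is fixed by the PTX ISA.
static __device__ __forceinline__ void mma_m16n8k8_f16_f32(
        float d[4], const unsigned a[2], const unsigned b[1], const float c[4]) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750 // Turing or newer
    asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 "
        "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};"
        : "=f"(d[0]), "=f"(d[1]), "=f"(d[2]), "=f"(d[3])
        : "r"(a[0]), "r"(a[1]), "r"(b[0]),
          "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
#else
    // no suitable tensor cores: a real kernel would take a different code path
    (void) d; (void) a; (void) b; (void) c;
#endif
}
```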

t/s by batch size

| GPU | Model | Microbatch size | Test | t/s master | t/s 60958f6 | Speedup |
| --- | --- | --- | --- | --- | --- | --- |
| RTX 3090 | gemma 2B F16 | 8 | pp16384 | 917.94 | 883.48 | 0.96 |
| RTX 3090 | gemma 2B F16 | 16 | pp16384 | 1783.82 | 1717.41 | 0.96 |
| RTX 3090 | gemma 2B F16 | 32 | pp16384 | 3371.60 | 2946.57 | 0.87 |
| RTX 3090 | gemma 2B F16 | 64 | pp16384 | 5460.72 | 3931.72 | 0.72 |
| RTX 3090 | gemma 2B F16 | 128 | pp16384 | 7397.50 | 7546.77 | 1.02 |
| RTX 3090 | gemma 2B F16 | 256 | pp16384 | 8118.48 | 9925.36 | 1.22 |
| RTX 3090 | gemma 2B F16 | 512 | pp16384 | 8378.34 | 11162.57 | 1.33 |
| RTX 3090 | gemma 2B F16 | 1024 | pp16384 | 8825.44 | 11560.53 | 1.31 |
| RTX 3090 | gemma 2B F16 | 2048 | pp16384 | 9058.17 | 11861.19 | 1.31 |
| RTX 3090 | gemma 2B F16 | 4096 | pp16384 | 9066.72 | 11855.52 | 1.31 |
| RTX 3090 | gemma 2B F16 | 8192 | pp16384 | 9062.71 | 11880.34 | 1.31 |
| RTX 3090 | gemma 2B F16 | 16384 | pp16384 | 9059.61 | 11881.52 | 1.31 |
| RTX 3090 | llama 8B Q4_0 | 8 | pp16384 | 476.90 | 423.76 | 0.89 |
| RTX 3090 | llama 8B Q4_0 | 16 | pp16384 | 1000.21 | 852.93 | 0.85 |
| RTX 3090 | llama 8B Q4_0 | 32 | pp16384 | 1481.18 | 1463.63 | 0.99 |
| RTX 3090 | llama 8B Q4_0 | 64 | pp16384 | 2282.17 | 2249.85 | 0.99 |
| RTX 3090 | llama 8B Q4_0 | 128 | pp16384 | 2699.08 | 2872.71 | 1.06 |
| RTX 3090 | llama 8B Q4_0 | 256 | pp16384 | 2961.71 | 3283.12 | 1.11 |
| RTX 3090 | llama 8B Q4_0 | 512 | pp16384 | 3141.05 | 3433.75 | 1.09 |
| RTX 3090 | llama 8B Q4_0 | 1024 | pp16384 | 3258.08 | 3515.12 | 1.08 |
| RTX 3090 | llama 8B Q4_0 | 2048 | pp16384 | 3209.13 | 3486.96 | 1.09 |
| RTX 3090 | llama 8B Q4_0 | 4096 | pp16384 | 3208.47 | 3489.13 | 1.09 |
| RTX 3090 | llama 8B Q4_0 | 8192 | pp16384 | 3211.33 | 3490.26 | 1.09 |
| RTX 3090 | llama 8B Q4_0 | 16384 | pp16384 | 3211.22 | 3493.71 | 1.09 |
| RTX 3090 | phi2 3B F16 | 8 | pp16384 | 578.06 | 535.17 | 0.93 |
| RTX 3090 | phi2 3B F16 | 16 | pp16384 | 1125.97 | 1042.01 | 0.93 |
| RTX 3090 | phi2 3B F16 | 32 | pp16384 | 1989.49 | 1945.43 | 0.98 |
| RTX 3090 | phi2 3B F16 | 64 | pp16384 | 2915.99 | 3268.04 | 1.12 |
| RTX 3090 | phi2 3B F16 | 128 | pp16384 | 4020.39 | 4724.74 | 1.18 |
| RTX 3090 | phi2 3B F16 | 256 | pp16384 | 4519.19 | 5469.01 | 1.21 |
| RTX 3090 | phi2 3B F16 | 512 | pp16384 | 5079.62 | 5849.03 | 1.15 |
| RTX 3090 | phi2 3B F16 | 1024 | pp16384 | 5231.56 | 5731.99 | 1.10 |
| RTX 3090 | phi2 3B F16 | 2048 | pp16384 | 5175.37 | 5708.27 | 1.10 |
| RTX 3090 | phi2 3B F16 | 4096 | pp16384 | 5175.80 | 5712.36 | 1.10 |
| RTX 3090 | phi2 3B F16 | 8192 | pp16384 | 5180.30 | 5713.81 | 1.10 |
| RTX 3090 | phi2 3B F16 | 16384 | pp16384 | 5182.75 | 5716.28 | 1.10 |
| RTX 4090 | gemma 2B F16 | 8 | pp16384 | 1171.05 | 1207.75 | 1.03 |
| RTX 4090 | gemma 2B F16 | 16 | pp16384 | 2293.91 | 2442.94 | 1.06 |
| RTX 4090 | gemma 2B F16 | 32 | pp16384 | 4501.31 | 4703.87 | 1.05 |
| RTX 4090 | gemma 2B F16 | 64 | pp16384 | 7640.10 | 7877.96 | 1.03 |
| RTX 4090 | gemma 2B F16 | 128 | pp16384 | 11600.18 | 13510.96 | 1.16 |
| RTX 4090 | gemma 2B F16 | 256 | pp16384 | 15288.96 | 20374.54 | 1.33 |
| RTX 4090 | gemma 2B F16 | 512 | pp16384 | 17760.10 | 25777.80 | 1.45 |
| RTX 4090 | gemma 2B F16 | 1024 | pp16384 | 16882.25 | 25036.06 | 1.48 |
| RTX 4090 | gemma 2B F16 | 2048 | pp16384 | 16277.56 | 24266.67 | 1.49 |
| RTX 4090 | gemma 2B F16 | 4096 | pp16384 | 16274.46 | 24295.69 | 1.49 |
| RTX 4090 | gemma 2B F16 | 8192 | pp16384 | 16262.72 | 24302.73 | 1.49 |
| RTX 4090 | gemma 2B F16 | 16384 | pp16384 | 16260.68 | 24365.30 | 1.50 |
| RTX 4090 | llama 8B Q4_0 | 8 | pp16384 | 900.71 | 884.15 | 0.98 |
| RTX 4090 | llama 8B Q4_0 | 16 | pp16384 | 1568.39 | 1586.71 | 1.01 |
| RTX 4090 | llama 8B Q4_0 | 32 | pp16384 | 2541.87 | 2818.84 | 1.11 |
| RTX 4090 | llama 8B Q4_0 | 64 | pp16384 | 4238.90 | 4923.51 | 1.16 |
| RTX 4090 | llama 8B Q4_0 | 128 | pp16384 | 5394.40 | 6753.02 | 1.25 |
| RTX 4090 | llama 8B Q4_0 | 256 | pp16384 | 7308.41 | 8470.81 | 1.16 |
| RTX 4090 | llama 8B Q4_0 | 512 | pp16384 | 7759.90 | 9103.17 | 1.17 |
| RTX 4090 | llama 8B Q4_0 | 1024 | pp16384 | 7695.01 | 9156.21 | 1.19 |
| RTX 4090 | llama 8B Q4_0 | 2048 | pp16384 | 7369.96 | 8807.45 | 1.20 |
| RTX 4090 | llama 8B Q4_0 | 4096 | pp16384 | 7360.98 | 8801.88 | 1.20 |
| RTX 4090 | llama 8B Q4_0 | 8192 | pp16384 | 7373.74 | 8803.42 | 1.19 |
| RTX 4090 | llama 8B Q4_0 | 16384 | pp16384 | 7362.43 | 8810.38 | 1.20 |
| RTX 4090 | phi2 3B F16 | 8 | pp16384 | 707.18 | 691.02 | 0.98 |
| RTX 4090 | phi2 3B F16 | 16 | pp16384 | 1384.09 | 1364.58 | 0.99 |
| RTX 4090 | phi2 3B F16 | 32 | pp16384 | 2646.27 | 2678.72 | 1.01 |
| RTX 4090 | phi2 3B F16 | 64 | pp16384 | 4210.05 | 5014.03 | 1.19 |
| RTX 4090 | phi2 3B F16 | 128 | pp16384 | 6252.95 | 8428.48 | 1.35 |
| RTX 4090 | phi2 3B F16 | 256 | pp16384 | 9784.33 | 12150.60 | 1.24 |
| RTX 4090 | phi2 3B F16 | 512 | pp16384 | 10954.86 | 13642.19 | 1.25 |
| RTX 4090 | phi2 3B F16 | 1024 | pp16384 | 12024.48 | 15077.97 | 1.25 |
| RTX 4090 | phi2 3B F16 | 2048 | pp16384 | 11047.99 | 13698.47 | 1.24 |
| RTX 4090 | phi2 3B F16 | 4096 | pp16384 | 11066.46 | 13689.34 | 1.24 |
| RTX 4090 | phi2 3B F16 | 8192 | pp16384 | 11074.76 | 13689.12 | 1.24 |
| RTX 4090 | phi2 3B F16 | 16384 | pp16384 | 11058.57 | 13705.94 | 1.24 |
t/s by prompt length

| GPU | Model | Microbatch size | Test | t/s master | t/s 60958f6 | Speedup |
| --- | --- | --- | --- | --- | --- | --- |
| RTX 4090, RTX 4090 | gemma 2B F16 | 512 | pp512 | 28125.95 | 28839.76 | 1.03 |
| RTX 4090, RTX 4090 | gemma 2B F16 | 512 | pp1024 | 37059.39 | 37877.02 | 1.02 |
| RTX 4090, RTX 4090 | gemma 2B F16 | 512 | pp2048 | 42284.50 | 43333.59 | 1.02 |
| RTX 4090, RTX 4090 | gemma 2B F16 | 512 | pp4096 | 42934.60 | 45718.06 | 1.06 |
| RTX 4090, RTX 4090 | gemma 2B F16 | 512 | pp8192 | 38007.78 | 47298.68 | 1.24 |
| RTX 4090, RTX 4090 | gemma 2B F16 | 512 | pp16384 | 29515.14 | 34328.45 | 1.16 |
| RTX 4090, RTX 4090 | gemma 2B F16 | 512 | pp32768 | 20327.29 | 21018.23 | 1.03 |
| RTX 4090, RTX 4090 | gemma 2B F16 | 512 | pp65536 | 11536.28 | 11717.40 | 1.02 |
| RTX 4090, RTX 4090 | gemma 2B F16 | 512 | pp131072 | 5614.87 | 5601.61 | 1.00 |
| RTX 4090, RTX 4090 | llama 8B Q4_0 | 512 | pp512 | 11991.52 | 12146.32 | 1.01 |
| RTX 4090, RTX 4090 | llama 8B Q4_0 | 512 | pp1024 | 15425.90 | 15996.82 | 1.04 |
| RTX 4090, RTX 4090 | llama 8B Q4_0 | 512 | pp2048 | 18340.19 | 19053.39 | 1.04 |
| RTX 4090, RTX 4090 | llama 8B Q4_0 | 512 | pp4096 | 18503.14 | 19691.59 | 1.06 |
| RTX 4090, RTX 4090 | llama 8B Q4_0 | 512 | pp8192 | 16884.28 | 18732.80 | 1.11 |
| RTX 4090, RTX 4090 | llama 8B Q4_0 | 512 | pp16384 | 13785.88 | 16109.49 | 1.17 |
| RTX 4090, RTX 4090 | llama 8B Q4_0 | 512 | pp32768 | 9923.22 | 12390.37 | 1.25 |
| RTX 4090, RTX 4090 | llama 8B Q4_0 | 512 | pp65536 | 6336.18 | 8377.22 | 1.32 |
| RTX 4090, RTX 4090 | llama 8B Q4_0 | 512 | pp131072 | 3678.63 | 4909.84 | 1.33 |
| RTX 4090, RTX 4090 | phi2 3B F16 | 512 | pp512 | 18937.05 | 19018.88 | 1.00 |
| RTX 4090, RTX 4090 | phi2 3B F16 | 512 | pp1024 | 24509.45 | 25041.30 | 1.02 |
| RTX 4090, RTX 4090 | phi2 3B F16 | 512 | pp2048 | 27943.38 | 29823.23 | 1.07 |
| RTX 4090, RTX 4090 | phi2 3B F16 | 512 | pp4096 | 27902.36 | 30910.38 | 1.11 |
| RTX 4090, RTX 4090 | phi2 3B F16 | 512 | pp8192 | 24938.07 | 29083.45 | 1.17 |
| RTX 4090, RTX 4090 | phi2 3B F16 | 512 | pp16384 | 19545.08 | 24456.94 | 1.25 |
| RTX 4090, RTX 4090 | phi2 3B F16 | 512 | pp32768 | 13327.44 | 17485.20 | 1.31 |
| RTX 4090, RTX 4090 | phi2 3B F16 | 512 | pp65536 | 8084.14 | 10843.93 | 1.34 |

Performance for large batch sizes is good; for small batch sizes it is still suboptimal. Asymptotically, the speedup for long prompts is about 1.1-1.5x on the GPUs I've tested. Notably, the new kernel uses a stream-k decomposition, so the performance should generalize better to GPUs beyond the ones that I as a dev am optimizing for. Also, so far I've restricted the implementation to features that are available with Turing; there are some Ampere features that should be useful.
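
For illustration only, a stream-k style decomposition assigns each thread block of a fixed-size grid a contiguous slice of the flattened iteration space instead of launching one block per output tile; the sketch below shows only the partitioning idea and is not the kernel's actual scheduling code:

```cuda
// Illustration of the stream-k partitioning idea (not the actual kernel code):
// the flattened (output tile, KV chunk) iteration space is split evenly over a
// fixed grid, so a block may finish the tail of one output tile and start the
// next one, keeping per-SM work balanced for any batch size / prompt length.
__global__ void streamk_partition_example(const int niters_total) {
    const int iters_per_block = (niters_total + gridDim.x - 1) / gridDim.x;
    const int iter_begin      = blockIdx.x * iters_per_block;
    const int iter_end        = min(iter_begin + iters_per_block, niters_total);

    for (int iter = iter_begin; iter < iter_end; ++iter) {
        // ... map iter to (output tile, KV chunk), update the running softmax
        //     and output accumulators ...
        // Output tiles whose KV range is split across blocks are combined in a
        // separate fixup pass.
    }
}
```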

The file size of libggml-cuda.so with GGML_NATIVE=OFF decreases from 363 MB to 358 MB. There is no change in compilation time on my machine with multithreaded compilation because the build ends up waiting for MMQ anyway.

@github-actions bot added the labels Nvidia GPU (Issues specific to Nvidia GPUs), python (python script changes), and ggml (changes relating to the ggml tensor library for machine learning) on Feb 1, 2025
@JohannesGaessler (Collaborator, Author)

The CUDA 11.7 compilation problems on Windows seem to be caused by the movmatrix instruction, which transposes an 8x8 tile of 16-bit values, needing a higher PTX ISA version than the mma instruction, even though both require Turing or newer hardware. This is very annoying, but since this instruction is not performance-critical I think I can write a workaround, either with __shfl instructions or by going through shared memory.
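
For illustration, one possible shared-memory workaround is sketched below. It assumes the standard m8n8 .b16 fragment layout (lane l of the warp holds the two consecutive elements at row l/4, columns 2*(l%4) and 2*(l%4)+1) and is not necessarily the exact code that ended up in the PR:

```cuda
// Hedged sketch of a movmatrix replacement: transpose an 8x8 tile of 16-bit
// values through a per-warp shared memory scratch area. x holds this lane's
// two consecutive 16-bit elements; tile points to this warp's 8x8 scratch.
static __device__ __forceinline__ unsigned transpose_m8n8_b16(
        const unsigned x, unsigned short (*tile)[8]) {
    const int lane = threadIdx.x % 32;
    const int row  = lane / 4;
    const int col  = 2 * (lane % 4);

    // Scatter the two 16-bit halves of x to their transposed positions.
    tile[col    ][row] = (unsigned short)( x        & 0xFFFF);
    tile[col + 1][row] = (unsigned short)((x >> 16) & 0xFFFF);
    __syncwarp();

    // Gather the transposed pair back into a single 32-bit register.
    const unsigned lo = tile[row][col    ];
    const unsigned hi = tile[row][col + 1];
    __syncwarp();
    return lo | (hi << 16);
}
```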

```c
@@ -1775,7 +1775,7 @@ extern "C" {
            struct ggml_tensor * a,
            int k);

#define GGML_KQ_MASK_PAD 32
```
Member (review comment on the diff above):
Minor note for the future: we should hide this constant behind an API call

@sorasoras

> The CUDA 11.7 compilation problems on Windows seem to be caused by the movmatrix instruction, which transposes an 8x8 tile of 16-bit values, needing a higher PTX ISA version than the mma instruction, even though both require Turing or newer hardware. This is very annoying, but since this instruction is not performance-critical I think I can write a workaround, either with __shfl instructions or by going through shared memory.

Could there be a PTX kernel for Pascal via dp4a?

@JohannesGaessler (Collaborator, Author)

For Pascal I am not aware of any useful PTX instructions that are not already being used in CUDA code.

As of right now, quantized KV caches are converted to floating-point numbers for large-batch FlashAttention. It would in principle be possible to write a FlashAttention kernel that uses int8 arithmetic instead. This would very likely be faster on Pascal; on Turing or newer it's a bit unclear, since it's difficult to get good int8 tensor core utilization with the GGML quantization formats.
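
For reference, the Pascal-era building block for int8 arithmetic is the __dp4a intrinsic (compute capability 6.1 and newer); a hypothetical int8 kernel would build its Q·K^T accumulation out of instructions like this instead of tensor core mma. This is a sketch only, not code from this PR:

```cuda
// Sketch only: dot product of n int8 values (n divisible by 4, data assumed
// 4-byte aligned) using __dp4a, which multiplies four packed int8 pairs and
// adds the result to an int32 accumulator. Available since Pascal (sm_61).
static __device__ __forceinline__ int vec_dot_q8(
        const int8_t * x, const int8_t * y, const int n) {
    int sum = 0;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 610
    for (int i = 0; i < n; i += 4) {
        const int xi = *(const int *) (x + i); // 4 packed int8 values
        const int yi = *(const int *) (y + i);
        sum = __dp4a(xi, yi, sum);
    }
#else
    (void) x; (void) y; (void) n;
#endif
    return sum;
}
```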

@JohannesGaessler (Collaborator, Author)

I pushed a version with a workaround for movmatrix. The difference in total kernel runtime is ~1%; for the end-to-end runtime it should be negligible. I don't know how to query the maximum supported PTX version from CUDA code, so I made the switch based on CUDA 11 vs. 12; I'll ask about a better way to do this on Friday when I talk to NVIDIA engineers.

@JohannesGaessler (Collaborator, Author)

While going through the PTX ISA documentation I noticed that there is a table that maps PTX versions to CUDA versions. Using that table I could determine that movmatrix should be available starting with CUDA 11.8, and I changed the check accordingly.
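
A hedged sketch of what such a check can look like (the macro name is illustrative, not the one used in ggml):

```cuda
// movmatrix requires a newer PTX ISA than mma; per the PTX ISA version table it
// is only guaranteed to be available from CUDA 11.8 on, so gate its use on the
// toolkit version and otherwise use the shuffle/shared-memory workaround.
#include <cuda_runtime.h> // defines CUDART_VERSION, e.g. 11080 for CUDA 11.8

#if defined(CUDART_VERSION) && CUDART_VERSION >= 11080
#define EXAMPLE_MOVMATRIX_AVAILABLE // illustrative name, not the real macro
#endif
```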

JohannesGaessler and others added 2 commits February 2, 2025 18:29
Co-authored-by: Diego Devesa <slarengh@gmail.com>
Co-authored-by: Diego Devesa <slarengh@gmail.com>
@JohannesGaessler merged commit 864a0b6 into ggml-org:master on Feb 2, 2025
48 checks passed
@IMbackK (Collaborator) left a comment

There are multiple issues with making this work on AMD MFMA (and let's forget about WMMA). For one, the shapes are quite non-ideal: all AMD shapes are square and 8 is not among them (4x4xN, 16x16xN, 32x32xN, with 4x4 of course being the least efficient). The output layout is also different (and has its own private register space that we would have to move the results out of). Maybe you could make it perform better than the vector implementation, but it won't be ideal at all. It's not something I will be trying to tackle soon, as there is lower-hanging fruit.

I wish we could keep the WMMA implementation, as it performs decently on MFMA/rocWMMA for large batch sizes.

```diff
@@ -25,6 +25,7 @@
 #define CU_MEM_LOCATION_TYPE_DEVICE hipMemLocationTypeDevice
 #define CU_MEM_ACCESS_FLAGS_PROT_READWRITE hipMemAccessFlagsProtReadWrite
 #define CU_CHECK(fn) {hipError_t err = fn; if(err != hipSuccess) { GGML_ABORT("HipVMM Failure: %s\n", hipGetErrorString(err)); }}
+#define __shfl_sync(mask, var, laneMask, width) __shfl(var, laneMask, width)
 #define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width)
```
Collaborator (review comment on the diff above):
The synchronous versions of the shuffle operators have been supported since ROCm 6.2. We should probably add the ability to conditionally use them based on HIP_VERSION >= 60200000, even though this won't do anything for now, just so it doesn't get forgotten in the future.
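
A hedged sketch of the suggested guard (not part of this PR); HIP_VERSION is encoded as major*10000000 + minor*100000 + patch, so ROCm 6.2 corresponds to 60200000:

```cpp
// Only define the fallback macros when the HIP runtime predates the native
// *_sync shuffle variants (ROCm < 6.2); on newer ROCm the builtins are used.
#include <hip/hip_version.h> // defines HIP_VERSION

#if HIP_VERSION < 60200000
#define __shfl_sync(mask, var, laneMask, width)     __shfl(var, laneMask, width)
#define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width)
#endif
```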

@JohannesGaessler (Collaborator, Author)

> I wish we could keep the WMMA implementation, as it performs decently on MFMA/rocWMMA for large batch sizes.

For CUDA the implementation is definitely bad, and in a vacuum I would not want to invest the effort to maintain it long-term. However, I recognize that you're making valuable contributions to the project, and in return I'm willing to maintain the WMMA code. I would also be willing to merge an optional compilation argument for rocWMMA if I get a pledge from you that you will maintain it (since I currently do not have any suitable hardware for testing; this may change in March).

But I must stress that, relative to what the hardware is capable of, the performance of a HIP port of my WMMA code will always be bad. Ideally we would just have a dedicated ROCm implementation. I would be willing to learn ROCm in order to review any related code.

@IMbackK (Collaborator) commented Feb 2, 2025

Don't get me wrong, it's bad; it just performs better than the other attention implementations available in ggml for now. Let's discuss the details at a later date, since as I understand it the WMMA code is in no immediate risk of being removed, as it still benefits Volta as well.

@ggerganov (Member)

@JohannesGaessler After this change, the ggml-ci node that runs the CUDA builds is failing because it uses a V100:

https://github.com/ggml-org/ci/blob/results/llama.cpp/86/4a0b67a6c8f648c43ce8271f9cb2e12dd5df6e/ggml-4-x86-cuda-v100/stdall#L12402

How should the ggml-ci script be modified to fix this?

Nexesenex added a commit to Nexesenex/croco.cpp that referenced this pull request Feb 4, 2025
* CUDA: use mma PTX instructions for FlashAttention

* __shfl_sync workaround for movmatrix

* add __shfl_sync to HIP

Authors: Johannes Gaessler and Slaren
tinglou pushed a commit to tinglou/llama.cpp that referenced this pull request Feb 13, 2025
* CUDA: use mma PTX instructions for FlashAttention

* __shfl_sync workaround for movmatrix

* add __shfl_sync to HIP

Co-authored-by: Diego Devesa <slarengh@gmail.com>